The present workshop is based on the tutorial "Modèle Hiérarchique avec Stan" of Matthieu Authier & Eric Parent.
In this workshop, we use flowering date data collected between 1978 and 2016 and published in Wenden et al. (2016). Data can be downloaded from this Dryad repository. This dataset contains flowering dates of 9,691 individuals/clones of Prunus avium in Europe.
Below is a figure from Wenden et al. (2016) showing the 25 studied sites in 11 European countries. Flowering dates were recorded in 12 sites. Size of the circle is proportional to the number of cultivars recorded in each site.
# Load the flowering phenology dataset (Wenden et al. 2016, first sheet)
dataSakura <- read_excel("../data/Sweet_cherry_phenology_data_1978-2015.xlsx", sheet = 1)
dataSakura <- dataSakura[1:1000, ] %>%                  # keep only 1000 individuals (to shorten the model running time)
  dplyr::rename(Flowering = "Full Flowering") %>%       # response variable: the date of flowering
  filter(!is.na(Flowering), !is.na(Plantation)) %>%     # remove missing values
  dplyr::mutate(Age = Year - Plantation,                # create the "age" variable
                Age = ifelse(Age > 14, 14, Age)) %>%    # cap at 14 (yields 14 age classes)
  dplyr::select(Site, Age, Cultivar, Flowering)         # select the columns we are going to use
# Show the first 10 lines of the dataset
dataSakura[1:10, ] %>%
  kable(digits = 3) %>%
  kable_styling(font_size = 12,
                bootstrap_options = c("striped", "hover", "condensed"),
                full_width = FALSE)  # FALSE, not F (F is reassignable)
| Site | Age | Cultivar | Flowering |
|---|---|---|---|
| Montauban | 10 | Burlat | 73 |
| Montauban | 10 | Regina | 92 |
| Montauban | 10 | Satin Sumele | 87 |
| Montauban | 10 | Summit | 93 |
| Montauban | 10 | Bellise Bedel | 83 |
| Montauban | 9 | Ferlizac | 78 |
| Montauban | 10 | Ferlizac | 84 |
| Montauban | 9 | Fermina | 90 |
| Montauban | 10 | Fermina | 90 |
| Montauban | 9 | Fertille | 83 |
Variation in flowering date with tree age (14 age classes):
# Sample size (Effectif) and mean flowering date per age class
dataSakura %>%
  group_by(Age) %>%
  summarize(Effectif = n(),
            Flowering_mean = round(mean(Flowering, na.rm = TRUE), 1)) %>%
  kable() %>%
  kable_styling(font_size = 12,
                bootstrap_options = c("striped", "hover", "condensed"),
                full_width = FALSE)  # FALSE, not F (F is reassignable)
| Age | Effectif | Flowering_mean |
|---|---|---|
| 1 | 2 | 86.5 |
| 2 | 81 | 91.7 |
| 3 | 91 | 92.6 |
| 4 | 103 | 93.9 |
| 5 | 99 | 90.8 |
| 6 | 114 | 93.9 |
| 7 | 104 | 92.7 |
| 8 | 82 | 92.5 |
| 9 | 63 | 94.1 |
| 10 | 53 | 90.9 |
| 11 | 25 | 96.2 |
| 12 | 26 | 93.5 |
| 13 | 11 | 90.6 |
| 14 | 5 | 108.2 |
Variation in flowering date by site (12 sites):
# Sample size (Effectif) and mean flowering date per site
dataSakura %>%
  group_by(Site) %>%
  summarize(Effectif = n(),
            Flowering_mean = round(mean(Flowering, na.rm = TRUE), 1)) %>%
  kable() %>%
  kable_styling(font_size = 12,
                bootstrap_options = c("striped", "hover", "condensed"),
                full_width = FALSE)  # FALSE, not F (F is reassignable)
| Site | Effectif | Flowering_mean |
|---|---|---|
| Balandran | 198 | 91.2 |
| Carpentras | 322 | 91.4 |
| Montauban | 79 | 88.9 |
| St Epain | 221 | 97.5 |
| Toulenne | 39 | 94.1 |
We start with a simple model in which we aim to model the flowering date \(y_{ijk}\) of each individual \(i\) as a function of its age \(j\) and its site \(k\), such as:
\[\begin{align} y_{ijk} & \sim \mathcal{N}(\mu_{ijk},\sigma) \tag*{Likelihood}\\[3pt] \mu_{ijk} & = \beta_0 + \alpha_j + \delta_k \tag*{Linear model}\\[3pt] \beta_0 & \sim \mathcal{N}(\mu_y, 10) \tag*{Global intercept prior}\\[3pt] \alpha_j & \sim \mathcal{N}(0,\sigma_{age})\tag*{Distribution of varying age intercepts}\\[3pt] \delta_k & \sim \mathcal{N}(0,\sigma_{site}) \tag*{Distribution of varying site intercepts}\\ \end{align}\]We want to specify the priors for \(\sigma\), \(\sigma_{age}\) and \(\sigma_{site}\). For that, we partition the total variance \(\sigma_{tot}\) as follows:
\[\begin{align} \sigma^2_{tot} & = \sigma^2 + \sigma^2_{age} + \sigma^2_{site}\\[3pt] \sigma & = \sigma_{tot} \times \sqrt{\pi_1}\\[3pt] \sigma_{age} & = \sigma_{tot} \times \sqrt{\pi_2}\\[3pt] \sigma_{site} & = \sigma_{tot} \times \sqrt{\pi_3}\\[3pt] \end{align}\]with \(\sum_{l=1}^3\pi_l = 1\) (see the unit simplex in stan) and \(\sigma_{tot} \sim \mathcal{S}^+(0,1,3)\) (student prior with 3 degrees of freedom).
This model is an ANOVA with 2 factors (age & site).
/*----------------------- Data --------------------------*/
// Baseline hierarchical ANOVA: flowering = global intercept + age effect + site effect,
// with the total variance partitioned across components via a 3-simplex (pi).
data {
int<lower = 1> n_obs; // Total number of observations
int<lower = 1> n_age; // Number of different age classes
int<lower = 1> n_site; // Number of different sites
vector[n_obs] FLOWERING; // Response variable (flowering dates)
int<lower = 1, upper = n_age> AGE[n_obs]; // Age variable
int<lower = 1, upper = n_site> SITE[n_obs]; // Site variable
real prior_location_beta0; // Prior mean of the global intercept (empirical mean of y in the R code)
real<lower = 0.0> prior_scale_beta0; // Prior sd of the global intercept
}
/*----------------------- Parameters --------------------------*/
parameters {
simplex[3] pi; // unit simplex: variance shares of (residual, age, site); elements sum to one
real beta0; // global intercept
real<lower = 0.0> sigma_tot; // Total standard deviation
vector[n_age] alpha; // Age intercepts
vector[n_site] delta; // Site intercepts
}
/*------------------- Transformed Parameters --------------------*/
transformed parameters {
real sigma; // Residual standard deviation
real sigma_age; // Standard deviation of the age intercepts
real sigma_site; // Standard deviation of the site intercepts
vector[n_obs] mu; // linear predictor
// Variance partition: sigma_x = sqrt(pi_x) * sigma_tot, so sigma^2 + sigma_age^2 + sigma_site^2 = sigma_tot^2
sigma = sqrt(pi[1]) * sigma_tot;
sigma_age = sqrt(pi[2]) * sigma_tot;
sigma_site = sqrt(pi[3]) * sigma_tot;
mu = rep_vector(beta0, n_obs) + alpha[AGE] + delta[SITE];
}
/*----------------------- Model --------------------------*/
model {
// Priors
beta0 ~ normal(prior_location_beta0, prior_scale_beta0); // Prior of the global intercept
sigma_tot ~ student_t(3, 0.0, 1.0); // Prior of the total standard deviation (half-Student-t via the lower bound)
alpha ~ normal(0.0, sigma_age); // Prior of the age intercepts
delta ~ normal(0.0, sigma_site); // Prior of the site intercepts
// Likelihood
FLOWERING ~ normal(mu, sigma);
}
/*----------------- Generated Quantities ------------------*/
generated quantities {
vector[n_obs] log_lik; // Pointwise log-likelihood (used for WAIC / LOO)
vector[n_obs] y_rep; // Replicated data for posterior predictive checks
for(i in 1:n_obs) {
log_lik[i] = normal_lpdf(FLOWERING[i]| mu[i], sigma); // log probability density function
y_rep[i] = normal_rng(mu[i], sigma); // prediction from posterior
}
}
Input data:
# Data list passed to Stan for the baseline ANOVA model
# (use <- for assignment, not =)
list.baseline.model <- list(
  n_obs  = nrow(dataSakura),                 # total number of observations
  n_age  = length(unique(dataSakura$Age)),   # number of age classes
  n_site = length(unique(dataSakura$Site)),  # number of sites
  FLOWERING = dataSakura$Flowering,          # response variable
  AGE  = dataSakura$Age,                     # age index (values already in 1..n_age)
  SITE = as.numeric(factor(dataSakura$Site, levels = unique(dataSakura$Site))), # site index 1..n_site
  prior_location_beta0 = mean(dataSakura$Flowering), # intercept prior centered on the empirical mean
  prior_scale_beta0 = 10
)
Sampling:
# Draw posterior samples from the compiled baseline model
fit.baseline.model <- sampling(baseline.model,
                               data = list.baseline.model,
                               pars = c("beta0", "alpha", "delta",
                                        "sigma", "sigma_age", "sigma_site", "sigma_tot",
                                        "pi", "y_rep", "log_lik"),
                               save_warmup = FALSE,  # FALSE, not F; warmup draws are not needed downstream
                               iter = 2000,
                               chains = 4, cores = 4, thin = 1)
loo.baseline.model <- loo::loo(fit.baseline.model) # to compare model predictive ability
stan_rhat(fit.baseline.model) # quick visual check of the split-Rhat values
That's ok!
It would be better to also check the chain convergence, the effective sample size and the autocorrelation, but we will skip these steps here! More details here: https://mc-stan.org/bayesplot/articles/visual-mcmc-diagnostics.html.
Let's look at the parameter estimates of the standard deviation for the sites (\(\sigma_{site}\)), the age (\(\sigma_{age}\)), the residuals (\(\sigma\)), the total standard deviation (\(\sigma_{tot}\)) and the relative importance of each variance component (i.e. proportion of the total variance explained by each component, i.e. site, age and residuals).
Here is the coefficients table:
# Posterior summaries (mean, sd, quantiles, n_eff, Rhat) of the key parameters
print(fit.baseline.model, digits = 3, pars = c("beta0", "sigma", "sigma_age", "sigma_site", "sigma_tot", "pi"))
## Inference for Stan model: 62d24212c46de9892dd578dab95b34f3.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
## beta0 92.774 0.074 1.892 88.917 91.690 92.792 93.866 96.443 654 1.009
## sigma 7.079 0.003 0.175 6.743 6.958 7.076 7.197 7.430 4107 1.000
## sigma_age 1.530 0.023 0.676 0.597 1.050 1.409 1.860 3.205 839 1.007
## sigma_site 3.772 0.040 1.340 1.956 2.824 3.515 4.440 7.027 1145 1.007
## sigma_tot 8.270 0.024 0.757 7.367 7.772 8.105 8.548 10.183 1016 1.006
## pi[1] 0.747 0.003 0.109 0.486 0.686 0.764 0.828 0.901 1163 1.007
## pi[2] 0.040 0.001 0.036 0.005 0.016 0.029 0.051 0.138 945 1.007
## pi[3] 0.213 0.003 0.109 0.068 0.131 0.190 0.271 0.484 1205 1.008
##
## Samples were drawn using NUTS(diag_e) at Sat Jun 5 07:42:25 2021.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
Plotting the credible intervals of the parameters of interest
First, the relative importance of each variance component: \(\pi_1\) for the residuals, \(\pi_2\) for the age and \(\pi_3\) for the sites.
# 95% / 99% posterior intervals (median point estimate) for the simplex weights pi
fit.baseline.model %>%
  mcmc_intervals(regex_pars = "^pi",
                 prob = 0.95,
                 prob_outer = 0.99,
                 point_est = "median") +
  theme_bw() +
  theme(axis.text = element_text(size = 16))
The standard deviations:
# Lower / upper bounds of the 80% highest posterior density (HPD) interval
lower <- function(x, alpha = 0.8) {
  coda::HPDinterval(coda::as.mcmc(x), prob = alpha)[1]
}
upper <- function(x, alpha = 0.8) {
  coda::HPDinterval(coda::as.mcmc(x), prob = alpha)[2]
}
# Posterior mean, sd and HPD bounds as a length-4 numeric vector
get_summary <- function(x, alpha = 0.8) {
  c(mean(x), sd(x), coda::HPDinterval(coda::as.mcmc(x), prob = alpha))
}
# Posterior summaries (mean, sd, 80% HPD) of each variance component
summary_anova <- as.data.frame(
  do.call('rbind',
          lapply(c("sigma", "sigma_age", "sigma_site", "sigma_tot"),
                 function(param) {
                   get_summary(as.numeric(rstan::extract(fit.baseline.model, param)[[1]]))
                 }))
)
names(summary_anova) <- c("mean", "se", "lower", "upper")
summary_anova$component <- c("residual", "age", "site", "total")
# Point estimates with HPD intervals, components ordered by posterior mean
summary_anova %>%
  mutate(component = factor(component, levels = c("residual", "age", "site", "total")[order(mean)])) %>%
  ggplot(aes(x = component, y = mean)) +
  geom_linerange(aes(x = component, ymin = lower, ymax = upper)) +
  geom_point(size = 2) +
  ylab("Estimate") + xlab("Source of variation") +
  coord_flip() +
  theme_bw()
# For the graph, we could also have used:
# fit.baseline.model %>% mcmc_intervals(regex_pars = "^sigma",
#                                       prob = 0.95,
#                                       prob_outer = 0.99,
#                                       point_est = "median") + theme_bw() +
#   theme(axis.text = element_text(size = 16))
As sites seem to considerably impact the total variance, we can display the parameters \(\delta_k\):
# Observed per-site sample sizes and mean flowering dates (red overlay points below)
freq <- dataSakura %>%
group_by(Site) %>%
summarize(effectif = n(),
flowering = mean(Flowering, na.rm = TRUE)
)
# Overall mean flowering date, weighted by per-site sample size
moyenne = with(freq, sum(effectif * flowering) / sum(effectif))
# Append an "Average" row (row 6) holding the overall mean
freq[6, 1] = "Average"
freq[6, 2] = mean(freq$effectif)
freq[6, 3] = moyenne
# Posterior summaries of the per-site expected flowering date (beta0 + delta_k):
# each beta0 draw is repeated across the n_site columns (each= with byrow = TRUE),
# added to the matching delta draws, then get_summary is applied column-wise.
post_site <- as.data.frame(t(apply(matrix(rep(rstan::extract(fit.baseline.model, "beta0")$beta0, each = length(unique(dataSakura$Site))), ncol = length(unique(dataSakura$Site)), byrow = TRUE) + rstan::extract(fit.baseline.model, "delta")$delta, 2, get_summary)))
names(post_site) <- c("mean", "se", "lower", "upper")
post_site$where <- c(unique(dataSakura$Site)) #, "Average")
# Attach the observed means/sizes for the red overlay points
post_site <- cbind(post_site,
freq[match(post_site$where, freq$Site), c('flowering', 'effectif')]
)
post_site %>%
mutate(where = factor(where, levels = c(unique(dataSakura$Site), "Average")[order(mean)])) %>%
ggplot(aes(x = where, y = mean)) +
geom_linerange(aes(x = where, ymin = lower, ymax = upper)) +
geom_point(size = 2) +
geom_point(aes(x = where, y = flowering, size = effectif), color = 'red', alpha = 0.3) +
# NOTE(review): "95 -10:10" parses as 95 - (10:10), i.e. a single break at 85;
# probably 95 + (-10:10) (breaks 85..105) was intended -- verify.
scale_y_continuous(name = "Estimate (days)", breaks = 95 -10:10) +
xlab("Site") +
coord_flip() +
theme_bw()
# Posterior predictive check: observed density vs 50 replicated datasets
ppc_dens_overlay(y = dataSakura$Flowering,
                 as.matrix(fit.baseline.model, pars = "y_rep")[1:50, ]) +
  theme_bw() +
  theme(legend.text = element_text(size = 25),
        legend.title = element_text(size = 18),
        axis.text = element_text(size = 18),
        legend.position = c(0.8, 0.6))
To write our own likelihood function, we will use the `target +=` statement, which directly increments the log density of the posterior (up to an additive constant).
From Bob Carpenter's comment on Stack Overflow: target += u adds u to the target log density. The target density is the density from which the sampler samples and it needs to be equal to the joint density of all the parameters given the data up to a constant.
But first, we have to mathematically write our own likelihood!
Let's assume there are two genetically distinct types of individuals that can be differentiated based on their flowering date : early or late flowering. We want to identify individuals with the genetic potential of flowering earlier.
Let \(p\) be the probability that an individual \(i\) flowers late and \(1-p\) the probability that it flowers early.
Then, the model becomes: \[\begin{align} y_{ijk} & \sim \mathcal{N}(\mu_{ijk}^l,\sigma) \tag*{Likelihood}\\[3pt] \mu_{ijk}^l & = \beta_l + \alpha_j + \delta_k \tag*{Linear model}\\[3pt] \\ \end{align}\]with \(l \in \{1,2\}\). Therefore, we introduced a supplementary discrete latent variable: \(z_{ijk} \sim \mathcal{B}(p)\) which models the state (early \(l=1\) or late \(l=2\)) according to probability \(p\).
As a consequence, the likelihood is: \[\begin{align} \mathcal{L}(y_{ijk}) = (1-p) \times \mathcal{N}(\beta_1 + \alpha_j + \delta_k, \sigma) + p \times \mathcal{N}(\beta_2 + \alpha_j + \delta_k, \sigma) \\ \end{align}\] Thus, the log-likelihood is: \[\begin{align} l(y_{ijk}) = \log{[(1-p) \times \mathcal{N}(\beta_1 + \alpha_j + \delta_k, \sigma) + p \times \mathcal{N}(\beta_2 + \alpha_j + \delta_k, \sigma)]} \\ \end{align}\]Due to the presence of a discrete variable \(z_{ijk}\), the likelihood is then implemented using the target function instead of \(\sim\).
First way of doing it, we can directly write the likelihood function in the model block:
/*----------------------- Data --------------------------*/
// Two-component Gaussian mixture (early vs late flowering). The discrete latent
// state z is marginalized out, so the mixture log density is added with target +=.
data {
int<lower = 1> n_obs; // Total number of observations
int<lower = 1> n_age; // Number of different age classes
int<lower = 1> n_site; // Number of different sites
vector[n_obs] FLOWERING; // Response variable (flowering dates)
int<lower = 1, upper = n_age> AGE[n_obs]; // Age variable
int<lower = 1, upper = n_site> SITE[n_obs]; // Site variable
real prior_location_beta0; // Prior mean for the early-flowering intercept
real<lower = 0.0> prior_scale_beta0; // Prior sd for the early-flowering intercept
real prior_location_diff; // Prior mean of the early/late difference (days)
real<lower = 0.0> prior_scale_diff; // Prior sd of the early/late difference
}
/*----------------------- Parameters --------------------------*/
parameters {
real<lower = 0.0, upper = 1.0> p; // probability of the late-flowering component
simplex[3] pi; // unit simplex: variance shares of (residual, age, site); elements sum to one
real beta0; // standardized global intercept (non-centered; see transformed parameters)
real<lower = 0.0> sigma_tot; // Total standard deviation
vector[n_age] alpha; // Age intercepts
vector[n_site] delta; // Site intercepts
real<lower = 0> difference; // beta_2 - beta_1; positivity makes component 2 the "late" one (breaks label switching)
}
/*------------------- Transformed Parameters --------------------*/
transformed parameters {
real sigma; // Residual standard deviation
real sigma_age; // Standard deviation of the age intercepts
real sigma_site; // Standard deviation of the site intercepts
vector[n_obs] mu[2]; // linear predictor, one vector per mixture component
vector[2] beta; // component intercepts (1 = early, 2 = late)
// Non-centered parameterization: beta0 ~ N(0, 1) in the model block,
// so beta[1] ~ N(prior_location_beta0, prior_scale_beta0)
beta[1] = prior_location_beta0 + beta0 * prior_scale_beta0;
beta[2] = beta[1] + difference;
sigma = sqrt(pi[1]) * sigma_tot;
sigma_age = sqrt(pi[2]) * sigma_tot;
sigma_site = sqrt(pi[3]) * sigma_tot;
mu[1] = rep_vector(beta[1], n_obs) + alpha[AGE] + delta[SITE];
mu[2] = rep_vector(beta[2], n_obs) + alpha[AGE] + delta[SITE];
}
/*----------------------- Model --------------------------*/
model {
// Priors
beta0 ~ normal(0.0, 1.0); // standard normal (non-centered intercept)
difference ~ normal(prior_location_diff, prior_scale_diff);
sigma_tot ~ student_t(3, 0.0, 1.0); // Prior of the total standard deviation
alpha ~ normal(0.0, sigma_age); // Prior of the age intercepts
delta ~ normal(0.0, sigma_site); // Prior of the site intercepts
// Our own likelihood: log[(1-p) N(mu1, sigma) + p N(mu2, sigma)], computed
// stably on the log scale with log_sum_exp / log1m
for(i in 1:n_obs) {
target += log_sum_exp(log1m(p) + normal_lpdf(FLOWERING[i] | mu[1, i], sigma), log(p) + normal_lpdf(FLOWERING[i] | mu[2, i], sigma));
}
}
/*----------------------- Extracting the log-likelihood --------------------------*/
generated quantities {
vector[n_obs] log_lik; // Pointwise log-likelihood (same mixture density as the model block; for WAIC/LOO)
for(i in 1:n_obs) {
log_lik[i] = log_sum_exp(log1m(p) + normal_lpdf(FLOWERING[i] | mu[1, i], sigma), log(p) + normal_lpdf(FLOWERING[i] | mu[2, i], sigma));
}
}
Second way of doing it, custom-functions can be implemented by using the function block:
/*--------------------- Functions ------------------------*/
// Same mixture model as above, but the mixture density is packaged as a
// user-defined _lpdf (usable with ~) plus a matching _rng for predictions.
functions {
// log density of a two-component Gaussian mixture:
// log[(1 - prob) * N(y | location[1], scale) + prob * N(y | location[2], scale)]
real TwoGaussianMixture_lpdf(real y, real prob, vector location, real scale) {
real log_pdf[2];
log_pdf[1] = log1m(prob) + normal_lpdf(y| location[1], scale);
log_pdf[2] = log(prob) + normal_lpdf(y| location[2], scale);
return log_sum_exp(log_pdf);
}
// draw from the mixture: pick a component with Bernoulli(prob), then sample from it
real TwoGaussianMixture_rng(real prob, vector location, real scale) {
int z;
z = bernoulli_rng(prob);
return z ? normal_rng(location[2], scale) : normal_rng(location[1], scale);
}
}
/*----------------------- Data --------------------------*/
data {
int<lower = 1> n_obs; // Total number of observations
int<lower = 1> n_age; // Number of different age classes
int<lower = 1> n_site; // Number of different sites
vector[n_obs] FLOWERING; // Response variable (flowering dates)
int<lower = 1, upper = n_age> AGE[n_obs]; // Age variable
int<lower = 1, upper = n_site> SITE[n_obs]; // Site variable
real prior_location_beta0; // Prior mean for the early-flowering intercept
real<lower = 0.0> prior_scale_beta0; // Prior sd for the early-flowering intercept
real prior_location_diff; // Prior mean of the early/late difference (days)
real<lower = 0.0> prior_scale_diff; // Prior sd of the early/late difference
}
/*----------------------- Parameters --------------------------*/
parameters {
real<lower = 0.0, upper = 1.0> p; // probability of the late-flowering component
simplex[3] pi; // unit simplex: variance shares of (residual, age, site); elements sum to one
real beta0; // standardized global intercept (non-centered; see transformed parameters)
real<lower = 0.0> sigma_tot; // Total standard deviation
vector[n_age] alpha; // Age intercepts
vector[n_site] delta; // Site intercepts
real<lower = 0> difference; // beta_2 - beta_1; positivity breaks label switching
}
/*------------------- Transformed Parameters --------------------*/
transformed parameters {
real sigma; // Residual standard deviation
real sigma_age; // Standard deviation of the age intercepts
real sigma_site; // Standard deviation of the site intercepts
vector[n_obs] mu[2]; // linear predictor, one vector per mixture component
vector[2] beta; // component intercepts (1 = early, 2 = late)
// Non-centered parameterization: beta0 ~ N(0, 1) in the model block
beta[1] = prior_location_beta0 + beta0 * prior_scale_beta0;
beta[2] = beta[1] + difference;
sigma = sqrt(pi[1]) * sigma_tot;
sigma_age = sqrt(pi[2]) * sigma_tot;
sigma_site = sqrt(pi[3]) * sigma_tot;
mu[1] = rep_vector(beta[1], n_obs) + alpha[AGE] + delta[SITE];
mu[2] = rep_vector(beta[2], n_obs) + alpha[AGE] + delta[SITE];
}
/*----------------------- Model --------------------------*/
model {
// Priors
beta0 ~ normal(0.0, 1.0); // standard normal (non-centered intercept)
difference ~ normal(prior_location_diff, prior_scale_diff);
sigma_tot ~ student_t(3, 0.0, 1.0); // Prior of the total standard deviation
alpha ~ normal(0.0, sigma_age); // Prior of the age intercepts
delta ~ normal(0.0, sigma_site); // Prior of the site intercepts
// Our own likelihood: the _lpdf suffix lets us use the ~ notation
for(i in 1:n_obs) {
FLOWERING[i] ~ TwoGaussianMixture(p, to_vector(mu[1:2, i]), sigma);
}
}
/*----------------- Generated Quantities ------------------*/
generated quantities {
vector[n_obs] log_lik; // Pointwise log-likelihood (for WAIC/LOO)
vector[n_obs] y_rep; // Replicated data for posterior predictive checks
for(i in 1:n_obs) {
log_lik[i] = log_sum_exp(log1m(p) + normal_lpdf(FLOWERING[i] | mu[1, i], sigma),
log(p) + normal_lpdf(FLOWERING[i] | mu[2, i], sigma));
y_rep[i] = TwoGaussianMixture_rng(p, to_vector(mu[1:2, i]), sigma);
}
}
Input data:
# Data list passed to Stan for the two-component mixture model
# (use <- for assignment, not =)
listMix.stan <- list(n_obs = nrow(dataSakura),
                     n_age = length(unique(dataSakura$Age)),
                     n_site = length(unique(dataSakura$Site)),
                     FLOWERING = dataSakura$Flowering,
                     AGE = dataSakura$Age,
                     SITE = as.numeric(factor(dataSakura$Site, levels = unique(dataSakura$Site))),
                     prior_location_beta0 = mean(dataSakura$Flowering), # prior centered on the empirical mean
                     prior_scale_beta0 = 10,
                     prior_location_diff = 7,  # prior mean of the early/late difference (days)
                     prior_scale_diff = 3)
Sampling:
# Draw posterior samples from the mixture model (target += version)
fit.mixTarget.model <- sampling(mixTarget.code,
                                data = listMix.stan,
                                pars = c("p", "beta", "alpha", "delta",
                                         "sigma", "sigma_age", "sigma_site", "sigma_tot",
                                         "pi", "difference", "y_rep", "log_lik"),
                                save_warmup = FALSE,  # FALSE, not F
                                iter = 2000,
                                chains = 4,
                                cores = 4,
                                thin = 1)
## Warning: There were 1 divergent transitions after warmup. Increasing adapt_delta above 0.8 may help. See
## http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
## Warning: Examine the pairs() plot to diagnose sampling problems
Checking parameter convergence:
# Histogram of split-Rhat values across all parameters (should be close to 1)
stan_rhat(fit.mixTarget.model)
Parameter estimations:
# Posterior summaries of the mixture-model parameters
print(fit.mixTarget.model, digits = 3, pars = c("p", "beta", "sigma", "sigma_age", "sigma_site", "sigma_tot", "pi", "difference"))
## Inference for Stan model: b992c238037d6f657254e4dba82efd70.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
## p 0.466 0.002 0.082 0.280 0.424 0.470 0.510 0.604 2260 1.000
## beta[1] 88.696 0.071 1.976 84.939 87.459 88.642 89.885 92.761 781 1.012
## beta[2] 97.690 0.067 1.959 93.691 96.526 97.689 98.852 101.512 861 1.010
## sigma 5.486 0.013 0.423 4.865 5.196 5.416 5.686 6.604 992 1.000
## sigma_age 1.241 0.021 0.578 0.387 0.839 1.131 1.527 2.663 732 1.004
## sigma_site 3.586 0.037 1.238 1.917 2.733 3.341 4.156 6.634 1143 1.000
## sigma_tot 6.767 0.029 0.863 5.614 6.172 6.599 7.158 8.923 891 1.000
## pi[1] 0.676 0.004 0.123 0.387 0.602 0.696 0.766 0.862 1229 1.000
## pi[2] 0.040 0.001 0.037 0.003 0.016 0.029 0.052 0.139 917 1.003
## pi[3] 0.284 0.004 0.127 0.099 0.190 0.261 0.355 0.592 1233 1.000
## difference 8.994 0.037 1.142 5.923 8.553 9.228 9.725 10.518 969 1.000
##
## Samples were drawn using NUTS(diag_e) at Sat Jun 5 07:44:02 2021.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
Comparison of the estimation of the log-likelihood between the baseline and the mix models (using WAIC):
# WAIC comparison (best model first, elpd differences relative to it)
loo::loo_compare(
  loo::waic(rstan::extract(fit.baseline.model, "log_lik")$log_lik),
  loo::waic(rstan::extract(fit.mixTarget.model, "log_lik")$log_lik)
)
## Warning:
## 1 (0.1%) p_waic estimates greater than 0.4. We recommend trying loo instead.
## Warning:
## 1 (0.1%) p_waic estimates greater than 0.4. We recommend trying loo instead.
## elpd_diff se_diff
## model2 0.0 0.0
## model1 -1.8 2.7
# LOO-CV (leave-one-out cross-validation)
loo.mixTarget.model <- loo::loo(fit.mixTarget.model) # namespace-qualified, consistent with loo::loo above
loo::loo_compare(loo.baseline.model, loo.mixTarget.model)
## elpd_diff se_diff
## model2 0.0 0.0
## model1 -1.8 2.7
Importance of each factor (age, site, others) in the total variance:
# Posterior summaries (mean, sd, 80% HPD) of each variance component of the mixture model
summary_anova <- as.data.frame(
  do.call('rbind',
          lapply(c("sigma", "sigma_age", "sigma_site", "sigma_tot"),
                 function(param) {
                   get_summary(as.numeric(rstan::extract(fit.mixTarget.model, param)[[1]]))
                 }))
)
names(summary_anova) <- c("mean", "se", "lower", "upper")
summary_anova$component <- c("residual", "age", "site", "total")
# Point estimates with HPD intervals, components ordered by posterior mean
summary_anova %>%
  mutate(component = factor(component, levels = c("residual", "age", "site", "total")[order(mean)])) %>%
  ggplot(aes(x = component, y = mean)) +
  geom_linerange(aes(x = component, ymin = lower, ymax = upper)) +
  geom_point(size = 2) +
  ylab("Estimate") + xlab("Source of variation") +
  coord_flip() +
  theme_bw()
We are going to include the cultivars (100 cultivars) as predictors for estimating the latent state \(z_{ijk}\).
/*--------------------- Functions ------------------------*/
// Mixture model where the late-flowering probability p varies per observation
// through a logistic regression on cultivar (random cultivar intercepts eta).
functions {
// log density of a two-component Gaussian mixture
real TwoGaussianMixture_lpdf(real y, real prob, vector location, real scale) {
real log_pdf[2];
log_pdf[1] = log1m(prob) + normal_lpdf(y| location[1], scale);
log_pdf[2] = log(prob) + normal_lpdf(y| location[2], scale);
return log_sum_exp(log_pdf);
}
// draw from the mixture: pick a component with Bernoulli(prob), then sample from it
real TwoGaussianMixture_rng(real prob, vector location, real scale) {
int z;
z = bernoulli_rng(prob);
return z ? normal_rng(location[2], scale) : normal_rng(location[1], scale);
}
}
/*----------------------- Data --------------------------*/
data {
int<lower = 1> n_obs; // Total number of observations
int<lower = 1> n_age; // Number of different age classes
int<lower = 1> n_site; // Number of different sites
int<lower = 1> n_cultivar; // Number of different cultivars
vector[n_obs] FLOWERING; // Response variable (flowering dates)
int<lower = 1, upper = n_age> AGE[n_obs]; // Age variable
int<lower = 1, upper = n_site> SITE[n_obs]; // Site variable
int<lower = 1, upper = n_cultivar> CULTIVAR[n_obs]; // Cultivar variable
real prior_location_beta0; // Prior mean for the early-flowering intercept
real<lower = 0.0> prior_scale_beta0; // Prior sd for the early-flowering intercept
real prior_location_diff; // Prior mean of the early/late difference
real<lower = 0.0> prior_scale_diff; // Prior sd of the early/late difference
real prior_location_eta0; // Prior mean of the logit-scale intercept of p
real<lower = 0.0> prior_scale_eta0; // Prior sd of the logit-scale intercept of p
}
/*----------------------- Parameters --------------------------*/
parameters {
simplex[3] pi; // unit simplex: variance shares of (residual, age, site); elements sum to one
real beta0; // standardized global intercept (non-centered; see transformed parameters)
real<lower = 0.0> sigma_tot; // Total standard deviation
vector[n_age] alpha; // Age intercepts
vector[n_site] delta; // Site intercepts
real<lower = 0> difference; // beta_2 - beta_1; positivity breaks label switching
real<lower = 0.0> sigma_cultivar; // sd of the cultivar intercepts (logit scale)
real eta0; // logit-scale intercept of Pr(late flowering)
vector[n_cultivar] eta; // cultivar intercepts (logit scale)
}
/*------------------- Transformed Parameters --------------------*/
transformed parameters {
real sigma; // Residual standard deviation
real sigma_age; // Standard deviation of the age intercepts
real sigma_site; // Standard deviation of the site intercepts
vector[n_obs] mu[2]; // linear predictor, one vector per mixture component
vector[n_obs] p; // per-observation probability of late flowering
vector[2] beta; // component intercepts (1 = early, 2 = late)
// Non-centered parameterization: beta0 ~ N(0, 1) in the model block
beta[1] = prior_location_beta0 + beta0 * prior_scale_beta0;
beta[2] = beta[1] + difference;
sigma = sqrt(pi[1]) * sigma_tot;
sigma_age = sqrt(pi[2]) * sigma_tot;
sigma_site = sqrt(pi[3]) * sigma_tot;
// real + vector addition broadcasts beta[l] over the vector
mu[1] = beta[1] + alpha[AGE] + delta[SITE];
mu[2] = beta[2] + alpha[AGE] + delta[SITE];
// logistic regression of Pr(late) on cultivar
p = inv_logit(rep_vector(eta0, n_obs) + eta[CULTIVAR]);
}
/*----------------------- Model --------------------------*/
model {
// Priors
sigma_tot ~ student_t(3, 0.0, 1.0);
sigma_cultivar ~ student_t(3, 0.0, 1.0);
beta0 ~ normal(0.0, 1.0); // standard normal (non-centered intercept)
difference ~ normal(prior_location_diff, prior_scale_diff);
alpha ~ normal(0.0, sigma_age);
delta ~ normal(0.0, sigma_site);
eta0 ~ normal(prior_location_eta0, prior_scale_eta0);
eta ~ normal(0.0, sigma_cultivar);
// Likelihood: marginalized two-component mixture with observation-specific p[i]
for(i in 1:n_obs) {
target += log_sum_exp(log1m(p[i]) + normal_lpdf(FLOWERING[i] | mu[1, i], sigma),
log(p[i]) + normal_lpdf(FLOWERING[i] | mu[2, i], sigma));
}
}
/*----------------- Generated Quantities ------------------*/
generated quantities {
vector[n_obs] log_lik; // Pointwise log-likelihood (for WAIC/LOO)
vector[n_obs] y_rep; // Replicated data for posterior predictive checks
for(i in 1:n_obs) {
log_lik[i] = log_sum_exp(log1m(p[i]) + normal_lpdf(FLOWERING[i] | mu[1, i], sigma),
log(p[i]) + normal_lpdf(FLOWERING[i] | mu[2, i], sigma));
y_rep[i] = TwoGaussianMixture_rng(p[i], to_vector(mu[1:2, i]), sigma);
}
}
Input data:
# Data list passed to Stan for the cultivar mixture model
# (use <- for assignment, not =)
listMixCultivar.stan <- list(n_obs = nrow(dataSakura),
                             n_age = length(unique(dataSakura$Age)),
                             n_site = length(unique(dataSakura$Site)),
                             n_cultivar = length(unique(dataSakura$Cultivar)),
                             FLOWERING = dataSakura$Flowering,
                             AGE = dataSakura$Age,
                             SITE = as.numeric(factor(dataSakura$Site, levels = unique(dataSakura$Site))),
                             CULTIVAR = as.numeric(factor(dataSakura$Cultivar, levels = unique(dataSakura$Cultivar))),
                             prior_location_beta0 = mean(dataSakura$Flowering), # prior centered on the empirical mean
                             prior_scale_beta0 = 10,
                             prior_location_diff = 7,   # prior mean of the early/late difference (days)
                             prior_scale_diff = 3,
                             prior_location_eta0 = 0.0, # logit-scale prior for Pr(late flowering)
                             prior_scale_eta0 = 1.5)
Sampling:
# Draw posterior samples from the cultivar mixture model
fit.mixCultivar.model <- sampling(mixCultivar.code,
                                  data = listMixCultivar.stan,
                                  pars = c("eta0", "eta", "sigma_cultivar",
                                           "beta", "alpha", "delta",
                                           "sigma", "sigma_age", "sigma_site", "sigma_tot",
                                           "pi", "difference", "log_lik"),
                                  save_warmup = FALSE,  # FALSE, not F
                                  iter = 2000,
                                  chains = 4,
                                  cores = 4,
                                  thin = 1)
## Warning: There were 1 chains where the estimated Bayesian Fraction of Missing Information was low. See
## http://mc-stan.org/misc/warnings.html#bfmi-low
## Warning: Examine the pairs() plot to diagnose sampling problems
## Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#bulk-ess
## Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#tail-ess
Checking parameter convergence:
# Histogram of split-Rhat values across all parameters (should be close to 1)
stan_rhat(fit.mixCultivar.model)
Parameter estimations:
# Posterior summaries of the cultivar mixture model parameters
print(fit.mixCultivar.model, digits = 3, pars = c("eta0", "sigma_cultivar", "beta", "sigma", "sigma_age", "sigma_site", "sigma_tot", "pi", "difference"))
## Inference for Stan model: fcabd7671bcaec1ff44fa109a42d38f4.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
## eta0 -0.212 0.012 0.439 -1.086 -0.495 -0.211 0.066 0.649 1326 1.003
## sigma_cultivar 2.693 0.053 0.705 1.714 2.225 2.575 3.014 4.449 178 1.024
## beta[1] 88.213 0.060 1.666 84.906 87.182 88.176 89.199 91.625 761 1.003
## beta[2] 97.633 0.062 1.677 94.295 96.630 97.630 98.628 101.020 725 1.005
## sigma 5.292 0.011 0.235 4.860 5.127 5.282 5.441 5.783 483 1.010
## sigma_age 1.583 0.017 0.585 0.735 1.166 1.485 1.899 2.991 1141 1.005
## sigma_site 3.395 0.035 1.125 1.844 2.586 3.185 3.961 6.070 1045 1.001
## sigma_tot 6.571 0.025 0.721 5.607 6.081 6.425 6.908 8.332 829 1.004
## pi[1] 0.666 0.004 0.115 0.402 0.597 0.681 0.751 0.847 1057 1.002
## pi[2] 0.065 0.001 0.046 0.013 0.032 0.053 0.083 0.186 1133 1.004
## pi[3] 0.269 0.004 0.119 0.101 0.180 0.248 0.338 0.545 1103 1.001
## difference 9.420 0.024 0.563 8.224 9.062 9.449 9.802 10.442 545 1.010
##
## Samples were drawn using NUTS(diag_e) at Sat Jun 5 07:45:20 2021.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
Comparison of the estimation of the log-likelihood between the two mix models (w or w/o cultivar variables):
# LOO-CV (leave-one-out cross-validation) for the cultivar mixture model
loo.mixCultivar.model <- loo::loo(fit.mixCultivar.model) # namespace-qualified, consistent with loo::loo above
## Warning: Some Pareto k diagnostic values are too high. See help('pareto-k-diagnostic') for details.
# Compare the predictive ability of the three models (best model first)
loo::loo_compare(x = list(loo.baseline.model, loo.mixTarget.model, loo.mixCultivar.model))
## elpd_diff se_diff
## model3 0.0 0.0
## model2 -91.5 11.7
## model1 -93.3 12.3
Finally, we can display the correlation between cultivars and the probability of late flowering:
# Posterior draws of Pr(late flowering) per cultivar: plogis(eta0 + eta[, j]).
# BUG FIX: the original built the eta0 matrix with
#   matrix(rep(eta0, each = n_cultivar), byrow = FALSE, ...)
# which, with column-major fill, does NOT repeat each posterior draw across the
# cultivar columns (compare the correct `each=` + `byrow = TRUE` pattern used
# for beta0 earlier). Adding the length-n_draws vector eta0 directly to the
# n_draws x n_cultivar matrix eta recycles eta0 down each column, giving
# element [i, j] = eta0[i] + eta[i, j] as intended.
proba_late <- plogis(rstan::extract(fit.mixCultivar.model, "eta0")$eta0 +
                       rstan::extract(fit.mixCultivar.model, "eta")$eta)
# Posterior mean and 80% HPD interval of Pr(late) for each cultivar
data.frame(id = unique(dataSakura$Cultivar),
           proba = apply(proba_late, 2, mean),
           lower = apply(proba_late, 2, lower),
           upper = apply(proba_late, 2, upper)) %>%
  mutate(id = factor(id, levels = id[order(proba)])) %>%  # order cultivars by posterior mean
  ggplot(aes(x = id, y = proba)) +
  geom_linerange(aes(x = id, ymin = lower, ymax = upper)) +
  geom_point() +
  xlab("cultivar") + ylab("Pr(Late Flowering)") +
  coord_flip() +
  theme(axis.text.y = element_text(size = 6))
Out of curiosity, we would like to see whether not using the simplex changes the estimates of \(\sigma\), \(\sigma_{age}\) and \(\sigma_{site}\). So, we redo the baseline model but without the simplex.
# Compile the no-simplex variant of the baseline model
baseline.model.nosimplex <- stan_model("BaselineModelCode_NoSimplex.stan")
// Baseline model without the simplex variance partition: each standard
// deviation gets its own independent exponential(1) prior.
data {
  int<lower = 1> n_obs;                       // Total number of observations
  int<lower = 1> n_age;                       // Number of different age classes
  int<lower = 1> n_site;                      // Number of different sites
  vector[n_obs] FLOWERING;                    // Response variable (flowering dates)
  int<lower = 1, upper = n_age> AGE[n_obs];   // Age variable
  int<lower = 1, upper = n_site> SITE[n_obs]; // Site variable
  real prior_location_beta0;                  // Prior mean of the global intercept
  real<lower = 0.0> prior_scale_beta0;        // Prior sd of the global intercept
}
parameters {
  real beta0;                    // global intercept
  vector[n_age] alpha;           // Age intercepts
  vector[n_site] delta;          // Site intercepts
  // BUG FIX: standard deviations must be declared with <lower = 0.0>.
  // Without the constraint the sampler proposes negative values that the
  // exponential(1) priors and normal(mu, sigma) likelihood reject, which
  // contributes to the divergent transitions reported for this model.
  real<lower = 0.0> sigma;       // Residual standard deviation
  real<lower = 0.0> sigma_age;   // Standard deviation of the age intercepts
  real<lower = 0.0> sigma_site;  // Standard deviation of the site intercepts
}
transformed parameters {
  vector[n_obs] mu;              // linear predictor
  mu = rep_vector(beta0, n_obs) + alpha[AGE] + delta[SITE];
}
model {
  // Priors
  beta0 ~ normal(prior_location_beta0, prior_scale_beta0); // Prior of the global intercept
  alpha ~ normal(0.0, sigma_age);  // Prior of the age intercepts
  delta ~ normal(0.0, sigma_site); // Prior of the site intercepts
  sigma ~ exponential(1);          // independent priors on each sd (no simplex)
  sigma_age ~ exponential(1);
  sigma_site ~ exponential(1);
  // Likelihood
  FLOWERING ~ normal(mu, sigma);
}
generated quantities {
  vector[n_obs] log_lik;         // Pointwise log-likelihood (for WAIC/LOO)
  vector[n_obs] y_rep;           // Replicated data for posterior predictive checks
  for(i in 1:n_obs) {
    log_lik[i] = normal_lpdf(FLOWERING[i]| mu[i], sigma);
    y_rep[i] = normal_rng(mu[i], sigma);
  }
}
# Draw posterior samples from the no-simplex baseline model
fit.baseline.model.nosimplex <- sampling(baseline.model.nosimplex,
                                         data = list.baseline.model,
                                         pars = c("beta0", "alpha", "delta",
                                                  "sigma", "sigma_age", "sigma_site",
                                                  "y_rep", "log_lik"),
                                         save_warmup = FALSE,  # FALSE, not F
                                         iter = 2000,
                                         chains = 4, cores = 4, thin = 1)
## Warning: There were 36 divergent transitions after warmup. Increasing adapt_delta above 0.8 may help. See
## http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
## Warning: Examine the pairs() plot to diagnose sampling problems
## Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
## Running the chains for more iterations may help. See
## http://mc-stan.org/misc/warnings.html#tail-ess
# Posterior summaries for the no-simplex baseline model
print(fit.baseline.model.nosimplex,
digits = 3, pars = c("beta0", "sigma", "sigma_age", "sigma_site"))
## Inference for Stan model: 19b46d68dfd80e5d0f57ac5fbd3de601.
## 4 chains, each with iter=2000; warmup=1000; thin=1;
## post-warmup draws per chain=1000, total post-warmup draws=4000.
##
## mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
## beta0 92.785 0.055 1.471 89.983 91.868 92.801 93.715 95.675 717 1.007
## sigma 7.089 0.003 0.167 6.770 6.973 7.089 7.198 7.421 3253 1.002
## sigma_age 1.029 0.021 0.499 0.267 0.674 0.957 1.311 2.123 542 1.001
## sigma_site 2.933 0.022 0.859 1.695 2.322 2.803 3.369 5.022 1466 1.000
##
## Samples were drawn using NUTS(diag_e) at Sat Jun 5 07:46:00 2021.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at
## convergence, Rhat=1).
There are more warnings in this model, compared to the baseline model with a simplex.
# Compare posterior sd estimates (median + 95% CI) between the two baselines
list(baseline.model = fit.baseline.model,
     baseline.model.nosimplex = fit.baseline.model.nosimplex) %>%
  mclapply(function(x) {
    broom.mixed::tidyMCMC(x, pars = c("sigma", "sigma_site", "sigma_age"),
                          droppars = NULL, estimate.method = "median",
                          ess = FALSE, rhat = FALSE,       # FALSE/TRUE, not F/T (reassignable)
                          conf.int = TRUE, conf.level = 0.95)
  }) %>%
  bind_rows(.id = "model") %>%
  ggplot(aes(x = term, y = estimate, ymin = conf.low, ymax = conf.high, color = model)) +
  geom_pointinterval(position = position_dodge(width = .8), point_size = 5, alpha = 0.6, size = 8) +
  xlab("") +
  ylab("Standard deviation estimates") +
  labs(color = "Models") +
  theme(axis.text = element_text(size = 20),
        panel.grid.minor.x = element_blank(),
        panel.grid.major.x = element_blank())
# Pairs plots of the sd parameters, with divergent transitions highlighted
# Baseline model with the simplex
np <- nuts_params(fit.baseline.model)
mcmc_pairs(as.array(fit.baseline.model),
           np = np,
           pars = c("sigma", "sigma_site", "sigma_age"),
           off_diag_args = list(size = 1, alpha = 1/3),
           np_style = pairs_style_np(div_size = 1, div_shape = 19),
           max_treedepth = 10)
# Baseline model without the simplex
np <- nuts_params(fit.baseline.model.nosimplex)
mcmc_pairs(as.array(fit.baseline.model.nosimplex),
           np = np,
           pars = c("sigma", "sigma_site", "sigma_age"),
           off_diag_args = list(size = 1, alpha = 1/3),
           np_style = pairs_style_np(div_size = 1, div_shape = 19),
           max_treedepth = 10)
Wenden, Bénédicte, José Antonio Campoy, Julien Lecourt, Gregorio López Ortega, Michael Blanke, Sanja Radičević, Elisabeth Schüller, et al. 2016. “A Collection of European Sweet Cherry Phenology Data for Assessing Climate Change.” Scientific Data 3 (1). Nature Publishing Group: 1–10.